#!/usr/bin/env python

import os
from sailing_robot.navigation import Navigation
import rospy, math, time, collections
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib.image as mpimg

from std_msgs.msg import String, Float32, Float64
from sensor_msgs.msg import NavSatFix
import numpy as np

#import smopy

# Colour palette: matplotlib's "V2" default colour cycle
# (https://matplotlib.org/users/dflt_style_changes.html#colors-color-cycles-and-color-maps).
# Indices used in this file: 0 = heading / boat track, 1 = wind (and boat
# arrow when the sailing state is not 'normal'), 3 = waypoints,
# 7 = goal heading.
C = [
    '#1f77b4',
    '#ff7f0e',
    '#2ca02c',
    '#d62728',
    '#9467bd',
    '#8c564b',
    '#e377c2',
    '#7f7f7f',
    '#bcbd22',
    '#17becf',
]


class Debugging_2D_matplot():

    def __init__(self):
        """Set up the ROS node, load the waypoints and start the live plot.

        Reads the parameters 'navigation/utm_zone', 'config/rate',
        'wp/acceptRadius', 'wp/table' and either 'wp/list' or 'wp/tasks',
        subscribes to the boat state topics and finally calls
        init_plot() and update_plot(); the constructor therefore only
        returns once update_plot() does.

        Raises:
            RuntimeError: if no waypoint is configured. The plot is centred
                on the mean of the waypoints, so it cannot be built without
                at least one of them.
        """
        rospy.init_node("debugging_2D_matplot")

        utm_zone = rospy.get_param('navigation/utm_zone')
        self.nav = Navigation(utm_zone=utm_zone)

        self.rate = rospy.Rate(rospy.get_param("config/rate"))
        self.wp_radius = rospy.get_param('wp/acceptRadius')

        # Get waypoints: either an explicit list of names, or the names
        # referenced by the tasks that have a 'waypoint' entry.
        if rospy.has_param('wp/list'):
            wp_list = rospy.get_param('wp/list')
        elif rospy.has_param('wp/tasks'):
            tasks_list = rospy.get_param('wp/tasks')
            wp_list = [t['waypoint'] for t in tasks_list if 'waypoint' in t]
        else:
            wp_list = []

        if not wp_list:
            # Fail with an explicit message instead of an unrelated
            # NameError/IndexError further down.
            rospy.logwarn("No waypoint was found!")
            raise RuntimeError("No waypoint was found: set the 'wp/list' "
                               "or 'wp/tasks' parameter")

        wp_table = rospy.get_param('wp/table')
        wp_list = list(set(wp_list)) # print each point only once
        # Row 0 and row 1 hold the two components returned by
        # latlon_to_utm, one column per waypoint.
        self.wp_array = np.array([self.nav.latlon_to_utm(wp_table[wp][0], wp_table[wp][1]) for wp in wp_list]).T

        # The plot origin is the centroid of the waypoints.
        self.origin = [self.wp_array[0].mean(), self.wp_array[1].mean()]

        # Initialise every attribute the callbacks read or write *before*
        # subscribing: a message may be delivered as soon as the
        # subscriber exists, and e.g. update_heading reads self.wind_boat.
        self.sailing_state = 'normal'
        self.heading = 0
        self.goal_heading = 0
        self.wind_boat = 0
        self.wind_north = 0
        self.position_history = collections.deque(maxlen = 500)
        self.position = [1,1]

        # Subscribers init
        rospy.Subscriber('sailing_state', String, self.update_sailing_state)
        rospy.Subscriber('heading', Float32, self.update_heading)
        rospy.Subscriber('goal_heading', Float32, self.update_goal_heading)
        rospy.Subscriber('wind_direction_apparent', Float64, self.update_wind_direction)
        rospy.Subscriber('position', NavSatFix, self.update_position)

        # Last axis limits applied by update_window ([minx, maxx, miny, maxy]).
        self.window = [0,0,0,0]
        self.init_plot()
        self.update_plot()


    def update_sailing_state(self, msg):
        self.sailing_state = msg.data

    def update_position(self, msg):
        self.position = list(self.nav.latlon_to_utm(msg.latitude, msg.longitude))
        self.position[0] -= self.origin[0]
        self.position[1] -= self.origin[1]
        self.position_history.append(self.position)

    def update_heading(self, msg):
        self.heading = msg.data
        self.wind_north = np.radians(self.heading + self.wind_boat)

    def update_goal_heading(self, msg):
        self.goal_heading = msg.data

    def update_wind_direction(self, msg):
        self.wind_boat = msg.data
        self.wind_north = np.radians(self.heading + self.wind_boat)

    def get_bg_image(self):

        self.has_bg_img = False
        self.image = None
        side_dist = None 
        image_origin = None

        my_dir = os.path.dirname(__file__)
        image_dir = os.path.abspath(os.path.join(my_dir, '../../../utilities/map_bg_images'))

        # select the correct image if any
        for filename in os.listdir(image_dir):
            if not filename.endswith(".png"):
                continue

            data_filename = filename[:-4].split("_")
            current_img_origin = self.nav.latlon_to_utm(float(data_filename[0]), float(data_filename[1]))

            dist_to_origin = ((self.origin[0] - current_img_origin[0])**2 +
                              (self.origin[1] - current_img_origin[1])**2)**0.5
            
            if dist_to_origin < 1000:
                image_origin = current_img_origin 
                side_dist = float(data_filename[2])
                self.image = mpimg.imread(os.path.join(image_dir,filename))
                self.has_bg_img = True
                break
        
        if not self.has_bg_img:
            return
        
        minx = image_origin[0] - side_dist - self.origin[0]
        maxx = image_origin[0] + side_dist - self.origin[0]
        miny = image_origin[1] - side_dist - self.origin[1]
        maxy = image_origin[1] + side_dist - self.origin[1]
        
        image_size = (minx, maxx, miny, maxy)
        self.image_show = self.ax.imshow(self.image, extent=image_size)


    def init_plot(self):
        """Create the figure and its static artists.

        Shifts self.wp_array in place so that the waypoints are relative to
        self.origin, records in self.maxwpdist the largest absolute
        waypoint coordinate along each axis, then creates self.fig, the
        boat track line (self.boatline), two empty lines that only carry
        legend labels, the waypoint scatter (self.wpfig) and self.ax, and
        finally tries to load a background image.
        """
        # recenter wp to the origin
        origin_x, origin_y = self.origin
        self.wp_array[0] = self.wp_array[0] - origin_x
        self.wp_array[1] = self.wp_array[1] - origin_y

        self.maxwpdist = [np.abs(self.wp_array[0]).max(),
                          np.abs(self.wp_array[1]).max()]

        self.fig = plt.figure()
        self.boatline, = plt.plot([], [], c=C[0], label="Heading")
        # empty lines: they only provide the legend entries of the arrows
        for colour, label in ((C[7], "Goal heading"), (C[1], "Wind direction")):
            plt.plot([], [], c=colour, label=label)

        # display Waypoints
        self.wpfig = plt.scatter(self.wp_array[0], self.wp_array[1], c=C[3])

        plt.tight_layout()
        self.ax = plt.subplot(111)

        self.get_bg_image()


    def get_arrow(self, angle, color, reverse=False):

        figsize = self.fig.get_size_inches()
        scale_dx = figsize[0]/np.sqrt(figsize[0]**2 + figsize[1]**2)
        scale_dy = figsize[1]/np.sqrt(figsize[0]**2 + figsize[1]**2)

        if reverse:
            style = '<-'
        else:
            style = '->'

        arrow_ori = (0.88, 0.12)
        arrow_target = (arrow_ori[0] + 0.05*np.sin(angle)/scale_dx, arrow_ori[1] + 0.05*np.cos(angle)/scale_dy)
        arrow = self.ax.annotate("",
                                 xy=arrow_target, xycoords=self.ax.transAxes,
                                 xytext=arrow_ori, textcoords=self.ax.transAxes,
                                 arrowprops=dict(arrowstyle=style,
                                                 color=color,
                                                 connectionstyle="arc3"),
                                 )

        return arrow



    def animate(self, i):
        """Redraw the dynamic artists for one animation frame.

        Called by FuncAnimation (see update_plot). Updates the boat track,
        rebuilds the wind / heading / goal heading arrows and the boat
        arrow from the latest values received by the callbacks, and lets
        update_window adjust the axis limits.

        Args:
            i: frame counter supplied by FuncAnimation; only passed on to
                update_window.

        Returns:
            Tuple of the artists to redraw (needed because the animation
            uses blit=True); the background image comes first when there
            is one.
        """
        # The arrows are re-created on every frame. Remove those of the
        # previous frame first, otherwise they accumulate in the axes for
        # as long as the node runs. getattr: nothing to remove on the
        # first frame.
        for old_artist in getattr(self, '_frame_artists', ()):
            old_artist.remove()
        self._frame_artists = ()

        # Work on snapshots: the subscriber callbacks may update the
        # history and the position while this frame is being drawn.
        track = list(self.position_history)
        position = self.position

        if track:
            # coordinates are relative to the plot origin (see update_position)
            track_x, track_y = np.array(track).T
            self.boatline.set_data(track_x, track_y)

        heading_rad = np.radians(self.heading)

        wind_arrow = self.get_arrow(self.wind_north, C[1], reverse=True)

        heading_arrow = self.get_arrow(heading_rad, C[0])

        goal_heading_arrow = self.get_arrow(np.radians(self.goal_heading), C[7])

        # The boat arrow changes colour whenever the sailing state is not
        # the 'normal' one.
        arrow_col = C[0]
        if self.sailing_state != 'normal':
            arrow_col = C[1]
        arrow_dx = 0.1*np.sin(heading_rad)
        arrow_dy = 0.1*np.cos(heading_rad)
        boat_arrow = plt.arrow(position[0] - arrow_dx, position[1] - arrow_dy,
                               arrow_dx, arrow_dy,
                               head_width=0.5,
                               head_length=1.,
                               fc=arrow_col,
                               ec=arrow_col)

        self._frame_artists = (wind_arrow, heading_arrow, goal_heading_arrow, boat_arrow)

        self.update_window(i)
        if self.has_bg_img:
            return self.image_show, self.boatline, wind_arrow, boat_arrow, self.wpfig, goal_heading_arrow, heading_arrow, plt.legend()
        else:
            return self.boatline, wind_arrow, boat_arrow, self.wpfig, goal_heading_arrow, heading_arrow, plt.legend()


    def update_window(self, i):
        # update window only every second
        if not i%10 == 0:
            return 

        # rounding of the window size in m 
        rounding = 10.0

        # maximum distance to origin in both direction
        distx = max(self.maxwpdist[0], abs(self.position[0])) + rounding/2
        disty = max(self.maxwpdist[1], abs(self.position[1])) + rounding/2

        distx = int(round(distx/rounding)*rounding) + 1
        disty = int(round(disty/rounding)*rounding) + 1

        # scaling to keep x and y orthonormal
        figsize = self.fig.get_size_inches()
        scale_dx = figsize[0]/np.sqrt(figsize[0]**2 + figsize[1]**2)
        scale_dy = figsize[1]/np.sqrt(figsize[0]**2 + figsize[1]**2)
        norm = min(scale_dx, scale_dy)
        scale_dx = scale_dx/norm
        scale_dy = scale_dy/norm
        
        # decide which axis is the limiting one
        if distx*scale_dy > disty*scale_dx:
            dist = distx
        else:
            dist = disty

        minx = - dist * scale_dx
        maxx = + dist * scale_dx
        miny = - dist * scale_dy
        maxy = + dist * scale_dy 
        
        if self.window == [minx, maxx, miny, maxy]:
            return 

        self.window = [minx, maxx, miny, maxy]
        self.ax.axis(self.window)

        return 



    def update_plot(self):
        """Start the periodic redraw of the figure and show the window.

        self.animate is scheduled every 100 ms with blitting enabled.
        """
        # The animation object has to stay referenced while the window is
        # shown, hence the (otherwise unused) local variable.
        running_animation = animation.FuncAnimation(self.fig, self.animate,
                                                    interval=100, blit=True)
        plt.show()

if __name__ == '__main__':
    # Building the object starts the node and runs the plot (see __init__).
    try:
        Debugging_2D_matplot()
    except rospy.ROSInterruptException:
        # ROS interrupted the node: leave quietly
        pass
